{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from sen12ms_dataLoader import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data():\n",
    "\n",
    "    print('Loading data...')\n",
    "\n",
    "    sen12ms = SEN12MSDataset(\"/data/PublicData/DF2020/trn/\")\n",
    "\n",
    "    IDs = np.load('/data/PublicData/DF2020/trn/clean.npy')\n",
    "\n",
    "    N = IDs.shape[0]\n",
    "\n",
    "#     idx=np.arange(N)\n",
    "#     np.random.shuffle(idx)\n",
    "\n",
    "    trn_ids=IDs[:int(0.2*N),:]\n",
    " \n",
    "    s1_trn_name=[]\n",
    "    s2_trn_name=[]\n",
    "    y_trn_name=[]\n",
    "\n",
    "\n",
    "    season_dict={1:Seasons.SPRING,2:Seasons.SUMMER,3:Seasons.FALL,4:Seasons.WINTER}\n",
    "\n",
    "    print('loading training files...')\n",
    "\n",
    "    for i in tqdm(range(trn_ids.shape[0])):\n",
    "        s1_name,s2_name,y_name=sen12ms.get_s1s2lc_triplet(season_dict[trn_ids[i,0]], trn_ids[i,1], trn_ids[i,2],\n",
    "                                                                               s1_bands=S1Bands.ALL,s2_bands=S2Bands.ALL, lc_bands=LCBands.ALL)\n",
    "        s1_trn_name.append(s1_name)\n",
    "        s2_trn_name.append(s2_name)\n",
    "        y_trn_name.append(y_name)\n",
    "\n",
    "    s1_trn_name = np.array(s1_trn_name)\n",
    "    s2_trn_name = np.array(s2_trn_name)\n",
    "    y_trn_name = np.array(y_trn_name)\n",
    "\n",
    "    return s1_trn_name,s2_trn_name,y_trn_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|█▋        | 3759/22335 [00:00<00:00, 37581.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading data...\n",
      "loading training files...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22335/22335 [00:00<00:00, 41287.69it/s]\n"
     ]
    }
   ],
   "source": [
    "s1_trn_name,s2_trn_name,y_trn_name=load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "s1_data=np.zeros([s1_trn_name.shape[0],2,256,256],dtype='float16')\n",
    "s2_data=np.zeros([s1_trn_name.shape[0],13,256,256],dtype='float16')\n",
    "y_data=np.zeros([s1_trn_name.shape[0],1,256,256],dtype='float16')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rasterio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22335/22335 [03:19<00:00, 112.18it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(s1_data.shape[0])):\n",
    "    with rasterio.open(s1_trn_name[i]) as patch:\n",
    "        s1_data[i] = patch.read(list(range(1,3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22335/22335 [07:29<00:00, 49.67it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(s1_data.shape[0])):\n",
    "    with rasterio.open(s2_trn_name[i]) as patch:\n",
    "        s2_data[i] = patch.read(list(range(1,14)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_data=s2_data[:,[1,2,3,4,5,6,7,10,11,12],:,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_data=np.concatenate((s1_data,s2_data),axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "del s1_data,s2_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22335/22335 [02:29<00:00, 149.72it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(s_data.shape[0])):\n",
    "    with rasterio.open(y_trn_name[i]) as patch:\n",
    "        y_data[i] = patch.read(list(range(1,2)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "data clean & label mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_data[np.isnan(s_data)] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 22335/22335 [52:32<00:00,  7.08it/s] \n"
     ]
    }
   ],
   "source": [
    "lab_dict={1:1,2:1,3:1,4:1,5:1,6:2,7:2,8:3,9:3,10:4,11:5,12:6,14:6,13:7,15:8,16:9,17:10}\n",
    "for i in tqdm(range(s_data.shape[0])):\n",
    "    tmp=y_data[i,:,:,:]\n",
    "    tmp=tmp.reshape(-1)\n",
    "    y = list(map(lambda x: lab_dict[x], tmp))\n",
    "    y=np.array(y)#list->array\n",
    "    y=y.reshape(1,256,256)\n",
    "    y_data[i,:,:,:]=y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_data=y_data-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float16)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(y_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_data=y_data.astype('uint8')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(22335, 12, 256, 256)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▍         | 1024/22335 [00:45<16:55, 20.98it/s]/data/di.wang/.conda/envs/pytorch/lib/python3.5/site-packages/ipykernel_launcher.py:14: RuntimeWarning: overflow encountered in add\n",
      "  \n",
      " 58%|█████▊    | 12843/22335 [09:31<07:45, 20.39it/s]/data/di.wang/.conda/envs/pytorch/lib/python3.5/site-packages/ipykernel_launcher.py:10: RuntimeWarning: divide by zero encountered in true_divide\n",
      "  # Remove the CWD from sys.path while we load stuff.\n",
      "/data/di.wang/.conda/envs/pytorch/lib/python3.5/site-packages/ipykernel_launcher.py:11: RuntimeWarning: invalid value encountered in true_divide\n",
      "  # This is added back by InteractiveShellApp.init_path()\n",
      "100%|██████████| 22335/22335 [16:37<00:00, 22.40it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(s_data.shape[0])):\n",
    "    x=s_data[i]\n",
    "    R = x[4, :, :]\n",
    "    G = x[3, :, :]\n",
    "    B = x[2, :, :]\n",
    "    Nir = x[8, :, :]  # TM4\n",
    "    Mir = x[-2, :, :]  # TM5\n",
    "    SWir = x[-1, :, :]  # TM7\n",
    "\n",
    "    MSI = SWir / Nir\n",
    "    NDWI = (G - Nir) / (G + Nir)\n",
    "    NDVI = (Nir - R) / (Nir + R)\n",
    "    NDBBI = (1.5 * SWir - (Nir + G) / 2.) / (1.5 * SWir + (Nir + G) / 2.)  # 归一化差值裸地与\n",
    "    BSI = ((Mir + R) - (Nir + B)) / ((Mir + R) + (Nir + B))  # 裸土指数\n",
    "    \n",
    "    y_clean=y_data[i,0,:,:].copy()\n",
    "    y=y_data[i,0,:,:].copy()\n",
    "    \n",
    "    # 修正不符合要求的森林类\n",
    "    y_clean[np.where((NDVI > 0.75) & (y != 0))] = 10\n",
    "    # 修正不符合要求的灌木类\n",
    "    y_clean[np.where((NDVI > 0.2) & (NDVI < 0.35) & (MSI > 1.5) & (y != 1))] = 10\n",
    "    # 修正不符合要求的草地类\n",
    "    y_clean[np.where((NDVI > 0.4) & (NDVI < 0.55) & (y != 3))] = 10\n",
    "    # 修正不符合要求的湿地类\n",
    "    y_clean[np.where((NDVI > 0.6) & (NDVI < 0.75) & (y != 4))] = 10\n",
    "    # 修正不符合要求的农田类\n",
    "    y_clean[np.where((NDVI > 0.2) & (NDVI < 0.35) & (MSI > 1) & (MSI < 1.5) & (y != 5))] = 10\n",
    "    # 修正不符合要求的建筑类\n",
    "    y_clean[np.where((NDVI > 0.2) & (NDVI < 0.35) & (MSI > 0.9) & (MSI < 1) & (y != 6))] = 10\n",
    "    # 修正不符合要求的裸地类\n",
    "    y_clean[np.where((NDVI > 0) & (NDVI < 0.15) & (y != 8))] = 10\n",
    "    # 修正裸土建筑错分到其他类\n",
    "    y_clean[np.where((BSI > -0.4) & (NDVI < 0.15) & (y != 6) & (y != 8))] = 10\n",
    "    # 修正不符合要求的水体类\n",
    "    y_clean[np.where((NDWI > 0.) & (y != 10))] = 10\n",
    "    \n",
    "    # 修正其他类错标到森林\n",
    "    y_clean[np.where((NDVI < 0.75) & (y == 0))] = 10\n",
    "\n",
    "    # shrubland\n",
    "\n",
    "    # 将灌木标签修正为草地\n",
    "    y_clean[np.where((NDVI > 0.4) & (NDVI < 0.55) & (y == 1))] = 3\n",
    "\n",
    "    # savanna\n",
    "\n",
    "    # 将热带草原标签修正为草地\n",
    "    y_clean[np.where((NDVI > 0.4) & (NDVI < 0.55) & (y == 2) & (np.sum(y == 9) < 2000))] = 3\n",
    "    # 将热带草原标签修正为湿地\n",
    "    y_clean[np.where((NDVI > 0.6) & (NDVI < 0.75) & (y == 2) & (np.sum(y == 9) > 10000))] = 4\n",
    "\n",
    "    # grassland\n",
    "    \n",
    "    # 将草地标签修正为湿地\n",
    "    y_clean[np.where((NDVI > 0.6) & (NDVI < 0.75) & (y == 3) & (np.sum(y==9)>10000))] = 4\n",
    "\n",
    "    # wetland\n",
    "    # 将湿地修正为森林\n",
    "    y_clean[np.where((NDVI > 0.75) & (y == 4))] = 0\n",
    "    \n",
    "    y_data[i,0,:,:]=y_clean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_data[y_data==2]=10\n",
    "y_data[y_data==7]=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0,  1,  3,  4,  5,  6,  8,  9, 10], dtype=uint8)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(y_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "data norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#s1\n",
    "tmp1=s_data[:,:2,:,:]\n",
    "tmp1[tmp1<-25]=-25\n",
    "tmp1[tmp1>0]=0\n",
    "\n",
    "tmp1 = (tmp1 + 25) / 25 * 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.min(tmp1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_data[:,:2,:,:]=tmp1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "del tmp1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#s2\n",
    "tmp2=s_data[:,2:,:,:]\n",
    "tmp2[tmp2>10000]=10000\n",
    "tmp2[tmp2<0]=0\n",
    "\n",
    "tmp2/=10000*1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_data[:,2:,:,:]=tmp2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.max(tmp2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "delete bg data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "X=s_data.transpose(0,2,3,1).reshape(-1,12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "y=y_data.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1463746560, 12)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1463746560,)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx=y!=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "X=X[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(624407841, 12)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "y=y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(624407841,)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('/data/PublicData/DF2020/trn/transform/ML_trn.npy',X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('/data/PublicData/DF2020/trn/transform/ML_lab.npy',y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
